import torch.nn as nn
import dgl
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import math
from generative_model.generative_layers import CDRVAE,LightGCN_gen
import torch


class ReproduceDataset(Dataset):
    """Pair batches of overlap-user ids with slices of a second tensor.

    Item ``idx`` is a tuple ``(batch1, batch2)``. ``batch1`` holds up to
    ``batch_size`` entries of ``tensor1`` taken in the current shuffle order;
    ``batch2`` is the ``idx``-th of ``num_batches`` near-equal contiguous
    slices of ``tensor2``. Both are returned flattened.
    """

    def __init__(self, tensor1, tensor2, batch_size):
        """
        Args:
            tensor1 (torch.Tensor): ids of the overlap users.
            tensor2 (torch.Tensor): edge ids, spread evenly over the batches.
            batch_size (int): batch size used for ``tensor1``.

        Raises:
            ValueError: if ``batch_size`` is not a positive integer.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer.")
        assert len(tensor2) >= len(tensor1), "tensor2 的长度必须大于或等于 tensor1 的长度！"
        self.tensor1 = tensor1
        self.tensor2 = tensor2
        self.batch_size = batch_size
        # Number of batches needed to cover tensor1 (ceiling division).
        self.num_batches = (len(tensor1) + batch_size - 1) // batch_size
        # Identity order until shuffle() is called.
        self.shuffle_indices = torch.arange(len(self.tensor1))

    def shuffle(self):
        """Draw a new random order for ``tensor1``."""
        self.shuffle_indices = torch.randperm(len(self.tensor1))

    def __len__(self):
        return self.num_batches

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(batch1, batch2)`` pair.

        Negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
                Raising here also lets plain ``for`` iteration terminate.
        """
        if idx < 0:
            idx += self.num_batches
        if not 0 <= idx < self.num_batches:
            raise IndexError("batch index out of range")

        # Slice the permutation first so only one batch of tensor1 is gathered
        # instead of re-permuting the whole tensor on every access.
        start1 = idx * self.batch_size
        end1 = min(start1 + self.batch_size, len(self.tensor1))
        batch1 = self.tensor1[self.shuffle_indices[start1:end1]]

        # tensor2 is cut into num_batches contiguous, near-equal slices.
        total2 = len(self.tensor2)
        start2 = idx * total2 // self.num_batches
        end2 = min((idx + 1) * total2 // self.num_batches, total2)
        batch2 = self.tensor2[start2:end2]
        return batch1.flatten(), batch2.flatten()
class Generator(nn.Module):
    def __init__(self,
                 config,
                 source_u,
                 source_i,
                 target_u,
                 target_i,
                 total_num_users,
                 total_num_items,
                 overlapped_num_users,
                 source_num_users,
                 source_num_items,
                 target_num_users,
                 target_num_items):
        """Build the interaction graph, the embeddings and the masking loader.

        Args:
            config: mapping read for ``device``, ``embedding_size``,
                ``generate_layers``, ``generative_model`` and
                ``mask_num_overlap_users``.
            source_u, source_i: user / item id tensors of the source
                interactions.
            target_u, target_i: user / item id tensors of the target
                interactions.
            total_num_users, total_num_items: sizes of the embedding tables
                (also used as the node count of the graph).
            overlapped_num_users: number of overlap users; ids
                ``1..overlapped_num_users`` are the ones batched for masking.
            source_num_users, source_num_items, target_num_users,
            target_num_items: per-domain counts, stored as attributes.
        """
        super().__init__()
        self.config = config
        self.device = config['device']

        # Domain statistics.
        self.total_num_users = total_num_users
        self.total_num_items = total_num_items
        self.overlapped_num_users = overlapped_num_users
        self.source_num_users = source_num_users
        self.source_num_items = source_num_items
        self.target_num_users = target_num_users
        self.target_num_items = target_num_items

        # Raw interactions; the concatenation lists source edges first.
        self.source_u, self.source_i = source_u, source_i
        self.target_u, self.target_i = target_u, target_i
        self.user_inter = torch.cat((source_u, target_u), dim=0)
        self.item_inter = torch.cat((source_i, target_i), dim=0)
        graph = self.create_graph(self.user_inter, self.item_inter)
        self.interaction_graph = graph.to(self.device)

        # Learnable embeddings (registered in this order: users, then items).
        self.user_embedding = nn.Embedding(num_embeddings=self.total_num_users,
                                           embedding_dim=config['embedding_size'])
        self.item_embedding = nn.Embedding(num_embeddings=self.total_num_items,
                                           embedding_dim=config['embedding_size'])
        self.n_layers = config['generate_layers']

        # Generative backbone: the VAE variant when requested, otherwise the
        # LightGCN-based generator.
        backbone_cls = CDRVAE if config['generative_model'] == 'VAE' else LightGCN_gen
        self.generater = backbone_cls(self.config,
                                      self.total_num_users,
                                      self.total_num_items)

        # Overlap-user ids 1..overlapped_num_users in random order.
        self.overlap_users_data = torch.randperm(self.overlapped_num_users) + 1
        # Ids of all existing edges in random order.
        num_edges = self.user_inter.shape[0]
        self.exist_edeges = torch.randperm(num_edges)

        # Each dataset item is already one batch (user ids to mask plus a
        # slice of existing edge ids), so the loader uses batch_size=1.
        self.Dataset = ReproduceDataset(self.overlap_users_data.to(self.device),
                                        self.exist_edeges.to(self.device),
                                        self.config['mask_num_overlap_users'])
        self.OverlapUserDataloader = DataLoader(self.Dataset,
                                                batch_size=1,
                                                shuffle=False)
    def create_graph(self, src, dst):
        """Build a bidirected graph with self-loops from ``src -> dst`` edges.

        The graph has ``total_num_users + total_num_items`` nodes and is
        returned on ``self.device`` (the configured device) rather than on a
        hard-coded ``'cuda'``, so CPU-only configurations work as well.

        Args:
            src: source node ids of the edges.
            dst: destination node ids of the edges.

        Returns:
            The DGL graph placed on ``self.device``.
        """
        g = dgl.graph((src, dst), num_nodes=self.total_num_users + self.total_num_items)
        g = dgl.to_bidirected(g)
        g = dgl.add_self_loop(g)
        return g.to(self.device)
    def split_overlap_users_n_groups(self, mask_num_overlap_users):
        # 输出是N*batch_size的需要mask掉的users的id，以及
        # 创建并打乱 0 到 x 的序列
        x=self.overlapped_num_users
        arr = torch.arange(x)[torch.randperm(x)]
        # 计算 reshape 所需的元素总数(向上取整到 n 的倍数)
        total_elements = x
        rows = math.ceil(x / mask_num_overlap_users)
        needed = rows * mask_num_overlap_users
        # 如果不足，则填充随机数来凑够 needed 个元素
        if needed > total_elements:
            pad_size = needed - total_elements
            # 在 [0, x] 范围内随机填充 pad_size 个数
            pads = torch.randint(low=0, high=x + 1, size=(pad_size,))
            arr = torch.cat([arr, pads])
        arr = arr.reshape(-1,mask_num_overlap_users)
        return arr
    def _random_mask_edges(self,
                           tensor: torch.Tensor,
                           n: int) -> torch.Tensor:
        """
        随机选择张量中 `n` 个 `False` 位置并将其设为 `True`，保持原来 `True` 的位置不变。
        参数：
        tensor : torch.Tensor一个布尔类型的张量。
        n : int需要设为 `True` 的 `False` 位置的数量。
        返回：
        torch.Tensor
            修改后的张量，其中随机选择的 `n` 个 `False` 位置被设为 `True`。
        """
        # 找到所有为 False 的位置的扁平索引
        false_indices = torch.nonzero(~tensor, as_tuple=False).squeeze()
        num_false = false_indices.numel()
        # 随机选择 n 个 False 的扁平索引
        selected_indices = false_indices[torch.randperm(num_false)[:n]]
        tensor[selected_indices] = True
        return tensor
        # 产生一组mask掉的matrix和对应的labels
    def mask_interactions(self,masked_over_u_idx):
        # 输出需要mask掉的指定的边
        # 一部分是overlap_user的边，另一部分是随机mask掉一定数量的边
        # 全部的source interactions中的user节点
        u=self.source_u.to(self.device)
        # 将masked_over_u_idx中的users在u中mask掉。u中，在masked_over_u_idx中存在的是True，否则是False
        overlap_user_mask = torch.isin(u,masked_over_u_idx)
        # 总的user长度长于source_u，因此需要padding
        num_to_pad = self.user_inter.size(0)-overlap_user_mask.size(0)
        # 如果是target domain的，则不去动（全是False，0）
        padding = torch.zeros(num_to_pad, dtype=torch.bool, device=overlap_user_mask.device)
        edge_mask = torch.cat((overlap_user_mask, padding), dim=0)
        masked_edges_id = torch.where(edge_mask)[0]
        return masked_edges_id.to(self.device)
    def get_ego_embeddings(self):
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings
    def calculate_loss(self, mask_u_idx, exist_edges):
        """Compute the generator's training loss for one masking batch.

        Args:
            mask_u_idx: ids of the overlap users whose source interactions
                are masked out of the graph.
            exist_edges: ids of additional existing interactions that are
                also used as prediction targets.

        Returns:
            Whatever ``self.generater.forward(..., mode='train')`` returns.
        """
        # Interaction ids belonging to the masked users.
        hidden_edge_ids = self.mask_interactions(mask_u_idx)
        # Prediction targets: masked interactions plus the sampled existing ones.
        target_edge_ids = torch.unique(
            torch.cat((hidden_edge_ids, exist_edges), dim=0), sorted=False)

        ego_embeddings = self.get_ego_embeddings()
        users = self.user_inter.to(self.device)
        items = self.item_inter.to(self.device)

        # NOTE(review): hidden_edge_ids index the raw interaction list, while
        # interaction_graph was rebuilt by to_bidirected/add_self_loop in
        # create_graph -- confirm its edge ids still line up with them.
        masked_graph = dgl.remove_edges(self.interaction_graph, hidden_edge_ids)

        return self.generater.forward(all_embeddings=ego_embeddings,
                                      g=self.interaction_graph,
                                      mask_g=masked_graph,
                                      label_u=users[target_edge_ids],
                                      label_i=items[target_edge_ids],
                                      mode='train')
    def _get_topk_items4each_user(self,
                               A: torch.Tensor,
                               B: torch.Tensor,
                               k: int,
                               chunk_size: int = 1024):
        """
        在内存友好的前提下，获取 (A x B^T) 每行的 top-k 索引。
        通过分块的方式计算，避免一次性生成 (a, b) 过大的结果张量。
        参数：
        -------
        A : (a, dim) 的张量
        B : (b, dim) 的张量
        k : 每行要获取的前 k 大元素
        chunk_size : 每次处理 B 的行数（列块大小），可根据内存情况调节
        返回：
        -------
        top_indices : (a, k) 的张量，每行表示在 B 里最相似的 k 个向量索引
        """
        device = A.device
        a, dimA = A.shape
        b, dimB = B.shape
        # 用于存储当前全局 top-k 的 (a, k) 分值和索引
        # 值初始化为 -inf，索引初始化为 -1
        global_top_vals = torch.full((a, k), float('-inf'), device=device)
        global_top_inds = torch.full((a, k), -1, dtype=torch.long, device=device)
        start = 0
        while start < b:
            end = min(start + chunk_size, b)
            # 取 B 的一部分
            B_chunk = B[start:end, :]  # (chunk_len, dim)
            # 计算局部结果 A x B_chunk^T => (a, chunk_len)
            partial_result = torch.matmul(A, B_chunk.transpose(0, 1))

            # 在每行取出局部 top-k
            partial_top_vals, partial_top_inds = partial_result.topk(k, dim=1)
            # partial_top_inds 是相对于 B_chunk 的局部索引，需要加上 start 才是 B 的全局索引
            partial_top_inds += start
            # 将局部 (vals, inds) 与全局 (vals, inds) 合并
            # 先在列维度 concat => (a, 2k)
            merged_vals = torch.cat([global_top_vals, partial_top_vals], dim=1)
            merged_inds = torch.cat([global_top_inds, partial_top_inds], dim=1)
            # 在 merged 上再做一次 topk，得到新的全局 topk
            merged_top_vals, merged_top_pos = merged_vals.topk(k, dim=1)
            # merged_top_pos 是在 [0..2k) 间的索引，要映射回 merged_inds
            row_idx = torch.arange(a, device=device).unsqueeze(1)
            final_inds = merged_inds[row_idx, merged_top_pos]
            # 更新全局候选
            global_top_vals = merged_top_vals
            global_top_inds = final_inds
            start = end
        return global_top_inds

    def _get_global_topk_interactions(self,
                                      A: torch.Tensor,
                                      B: torch.Tensor,
                                      k: int=100000,
                                      chunk_size: int = 1024):
        """
        在内存友好的前提下，获取 (A x B^T) 中全局 top-k 分值的 interactions。
        通过分块的方式计算，避免一次性生成 (a, b) 过大的结果张量。
        参数：
        -------
        A : (a, dim) 的张量，表示用户向量
        B : (b, dim) 的张量，表示物品向量
        k : 要获取的前 k 大全局 interactions
        chunk_size : 每次处理 B 的行数（列块大小），可根据内存情况调节
        返回：
        -------
        top_interactions : (k, 2) 的张量，每行表示 [user_idx, item_idx]，按 score 降序排列
        """
        device = A.device
        a, dimA = A.shape
        b, dimB = B.shape
        # 用于存储当前全局 top-k 的 (k,) 分值、用户索引和物品索引
        # 值初始化为 -inf，索引初始化为 -1
        global_top_vals = torch.full((k,), float('-inf'), device=device)
        global_top_user_inds = torch.full((k,), -1, dtype=torch.long, device=device)
        global_top_item_inds = torch.full((k,), -1, dtype=torch.long, device=device)
        start = 0
        while start < b:
            end = min(start + chunk_size, b)
            # 取 B 的一部分
            B_chunk = B[start:end, :]  # (chunk_len, dim)
            chunk_len = end - start
            # 计算局部结果 A x B_chunk^T => (a, chunk_len)
            partial_result = torch.matmul(A, B_chunk.transpose(0, 1))

            # 将 partial_result 展平为 (a * chunk_len,)
            flat_partial = partial_result.view(-1)
            # 在局部展平张量上取出 top-k
            partial_top_vals, partial_top_flat_inds = flat_partial.topk(k)
            # 计算对应的 user_idx 和 item_idx
            partial_top_user_inds = partial_top_flat_inds // chunk_len
            partial_top_item_inds = (partial_top_flat_inds % chunk_len) + start

            # 将局部 (vals, user_inds, item_inds) 与全局合并
            # 先在维度上 concat => (2k,)
            merged_vals = torch.cat([global_top_vals, partial_top_vals], dim=0)
            merged_user_inds = torch.cat([global_top_user_inds, partial_top_user_inds], dim=0)
            merged_item_inds = torch.cat([global_top_item_inds, partial_top_item_inds], dim=0)

            # 在 merged 上再做一次 topk，得到新的全局 topk
            merged_top_vals, merged_top_pos = merged_vals.topk(k)
            # merged_top_pos 是在 [0..2k) 间的索引，要映射回 merged_user_inds 和 merged_item_inds
            final_user_inds = merged_user_inds[merged_top_pos]
            final_item_inds = merged_item_inds[merged_top_pos]

            # 更新全局候选
            global_top_vals = merged_top_vals
            global_top_user_inds = final_user_inds
            global_top_item_inds = final_item_inds
            start = end

        # 组合成 (k, 2) 的张量 [user_idx, item_idx]
        top_interactions = torch.stack([global_top_user_inds, global_top_item_inds], dim=1)
        return top_interactions


    def remove_zeros(self,tensor1, tensor2):
        """
        Removes elements with 0 in either tensor1 or tensor2 and their corresponding elements in the other tensor.
        Args:
            tensor1 (torch.Tensor): The first 1D tensor.
            tensor2 (torch.Tensor): The second 1D tensor of the same size as tensor1.
        Returns:
            torch.Tensor, torch.Tensor: Filtered tensor1 and tensor2.
        """
        if tensor1.size() != tensor2.size():
            raise ValueError("Both tensors must have the same size.")
        # Find indices where neither tensor1 nor tensor2 is equal to 0
        non_zero_indices = (tensor1 != 0) & (tensor2 != 0)
        # Filter both tensors
        filtered_tensor1 = tensor1[non_zero_indices]
        filtered_tensor2 = tensor2[non_zero_indices]
        return filtered_tensor1, filtered_tensor2

    def generate_edges(self):
        all_embeddings=self.get_ego_embeddings()
        generating_user_id = torch.arange(self.target_num_users+1, self.total_num_users)
        u_embeddings, i_embeddings = self.generater.forward(all_embeddings,self.interaction_graph,self.interaction_graph,generating_user_id,1,'generate')
        i_embeddings = i_embeddings[self.target_num_items+1:self.total_num_items]
        if self.config['generate_mode']=='generate4each':
           source_indexes = self._get_topk_items4each_user(u_embeddings,i_embeddings,self.config['generate_edges']).flatten()
        else:
            source_indexes = self._get_global_topk_interactions(u_embeddings, i_embeddings).flatten()
        generated_user_id = torch.repeat_interleave(generating_user_id, repeats=self.config['generate_edges'])
        re_source_u=torch.cat((self.source_u, generated_user_id.to('cpu')), dim=0)
        re_source_i=torch.cat((self.source_i, source_indexes.to('cpu')), dim=0)
        return re_source_u,re_source_i,u_embeddings, i_embeddings
    def data_reproduce(self):
        # if self.config['generate_setting']=='add':
        self.source_u,self.source_i,u_embeddings, i_embeddings=self.generate_edges()
        return self.source_u,self.source_i


def train1(config, model):
    """Train ``model`` with Adam until the epoch loss stops changing.

    Each epoch iterates ``model.OverlapUserDataloader``, calls
    ``model.calculate_loss`` on the flattened batch tensors and takes one
    optimizer step per batch. Training stops early once the absolute change
    of the average epoch loss stays below ``0.01`` for two consecutive
    epochs. If the dataloader yields no batches there is nothing to train on
    and the model is returned unchanged.

    Args:
        config: mapping read for ``generate_learning_rate`` and
            ``generate_epochs``.
        model: module exposing ``OverlapUserDataloader`` and
            ``calculate_loss(mask_u_idx, exist_edges)``.

    Returns:
        The same ``model`` object after training.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=config['generate_learning_rate'])

    # Early-stopping state.
    previous_loss = float('inf')  # average loss of the previous epoch
    patience = 2       # consecutive small-change epochs required to stop
    tolerance = 0.01   # a change below this counts as "no progress"
    stop_counter = 0   # consecutive epochs that met the stop condition

    for _epoch in range(config['generate_epochs']):
        # Training mode only needs to be set once per epoch, not per batch.
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        for mask_u_idx, exist_edges in model.OverlapUserDataloader:
            optimizer.zero_grad()
            loss = model.calculate_loss(mask_u_idx.flatten(), exist_edges.flatten())
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            # Empty dataloader: avoid dividing by zero; later epochs would be
            # empty as well.
            break

        avg_loss = epoch_loss / num_batches
        loss_diff = abs(previous_loss - avg_loss)
        if loss_diff < tolerance:
            stop_counter += 1
            if stop_counter >= patience:
                break
        else:
            stop_counter = 0

        previous_loss = avg_loss

    return model